{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAPoU8Sm5E6e"
      },
      "source": [
        "# Supervised Fine Tuning with Gemini 2.0 Flash for change detection using the Google Gen AI SDK\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_multiple_images.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Ftuning%2Fsft_gemini_on_multiple_images.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/tuning/sft_gemini_on_multiple_images.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_multiple_images.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_multiple_images.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_multiple_images.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_multiple_images.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_multiple_images.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_multiple_images.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "| Author(s) |\n",
        "| --- |\n",
        "| [Ivan Nardini](https://github.com/inardini) |\n",
        "| [Erwin Huizenga](https://github.com/erwinh85) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Fine-tuning Gemini on image data using supervised learning lets you specializing Gemini models for several vision tasks including visual inspection where you want to train a model to identify specific objects or defects either within or between images.\n",
        "\n",
        "This notebook demonstrates how to fine-tune the Gemini 2.0 model for change detection task (spot differences) with the Vertex AI Supervised Tuning feature using multiple images and text as inputs.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BHX-mT4NWVsZ"
      },
      "source": [
        "### Dataset\n",
        "\n",
        "Dataset used is a modified subsample of the [Spot-the-diff dataset](https://github.com/harsh19/spot-the-diff/tree/master), introduced by Jhamtani et al. in Learning to Describe Differences Between Pairs of Similar Images.\n",
        "\n",
        "```bibtex\n",
        "@inproceedings{jhamtani2018learning,\n",
        "  title={Learning to Describe Differences Between Pairs of Similar Images},\n",
        "  author={Jhamtani, Harsh and Berg-Kirkpatrick, Taylor},\n",
        "  booktitle={Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (EMNLP)},\n",
        "  year={2018}\n",
        "}\n",
        "```\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61RBz8LLbxCR"
      },
      "source": [
        "## Get started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "No17Cw5hgx12"
      },
      "source": [
        "### Install Google Gen AI SDK and other required packages\n",
        "The new Google Gen AI SDK provides a unified interface to Gemini through both the Gemini Developer API and the Gemini API on Vertex AI. With a few exceptions, code that runs on one platform will run on both. This means that you can prototype an application using the Developer API and then migrate the application to Vertex AI without rewriting your code.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFy3H3aPgx12"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --quiet google-cloud-aiplatform etils google-genai"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R5Xep4W9lq-Z"
      },
      "source": [
        "### Restart runtime\n",
        "\n",
        "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
        "\n",
        "The restart might take a minute or longer. After it's restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XRvKdaPDTznN"
      },
      "outputs": [],
      "source": [
        "import IPython\n",
        "\n",
        "app = IPython.Application.instance()\n",
        "app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SbmM4z7FOBpM"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. In Colab or Colab Enterprise, you might see an error message that says \"Your session crashed for an unknown reason.\" This is expected. Wait until it's finished before continuing to the next step. ⚠️</b>\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dmWOrTJ3gx13"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NyKGtVQjgx13"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF4l8DTdWgPY"
      },
      "source": [
        "### Set Google Cloud project information and initialize the Google Gen AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "code",
        "id": "Nqwi-5ufWp_B"
      },
      "outputs": [],
      "source": [
        "# Use the environment variable if the user doesn't provide Project ID.\n",
        "import os\n",
        "\n",
        "from google import genai\n",
        "from google.genai import types\n",
        "import vertexai\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "REGION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "BUCKET_NAME = \"[your-bucket-name]\"  # @param {type:\"string\"}\n",
        "BUCKET_URI = f\"gs://{BUCKET_NAME}\"\n",
        "\n",
        "vertexai.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)\n",
        "\n",
        "client = genai.Client(vertexai=True, project=PROJECT_ID, location=REGION)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sfIUgj-mU8K9"
      },
      "source": [
        "### Create a Cloud Storage bucket\n",
        "\n",
        "Only run the code below if you want to create a new Google Cloud Storage bucket."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M-L1BH8TU9Gn"
      },
      "outputs": [],
      "source": [
        "# ! gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5303c05f7aa6"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6fc324893334"
      },
      "outputs": [],
      "source": [
        "from io import BytesIO\n",
        "\n",
        "# General\n",
        "import json\n",
        "import random\n",
        "import time\n",
        "from typing import Any\n",
        "\n",
        "from PIL import Image\n",
        "from etils.epath import Path\n",
        "\n",
        "# For model evaluation.\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "# For model fine tuning."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fH2oJFXjOAFz"
      },
      "source": [
        "### Set constants"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IySWS37JOCJU"
      },
      "outputs": [],
      "source": [
        "INPUT_SOURCE_DATA_URI = \"gs://github-repo/tuning/data\"\n",
        "INPUT_DATA_URI = f\"{BUCKET_URI}/data\"\n",
        "MODEL_ID = \"gemini-2.0-flash-001\"  # @param {type:\"string\", isTemplate: true}\n",
        "\n",
        "SYSTEM_INSTRUCTION = \"\"\"You are an expert in \"spot the difference\" games. Your task is to analyze two images and identify their differences.\n",
        "        Instructions:\n",
        "        1. Carefully examine the two provided images.\n",
        "        2. List the differences you find between the two images in a clear and concise manner, using bullet points.  For example:\n",
        "            * \"The color of the car in image 1 is red, while in image 2 it is blue.\"\n",
        "            * \"The house in image 1 has a chimney, but the house in image 2 does not.\"\n",
        "        3. If the images are identical, output: \"no difference between two images\"\n",
        "    \"\"\".strip()\n",
        "\n",
        "TASK_PROMPT = \"Compare the two images and find differences.\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TbbbPuv2MLSY"
      },
      "source": [
        "### Set helpers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iMbvf4daTif-"
      },
      "outputs": [],
      "source": [
        "def save_jsonlines(instances: list[dict[str, Any]], file_path: str) -> None:\n",
        "    \"\"\"\n",
        "    Saves a list of JSON-serializable instances to a jsonlines file.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        bucket_path = Path(file_path)\n",
        "        with bucket_path.open(\"w\") as f:\n",
        "            for i, instance in enumerate(instances):\n",
        "                try:\n",
        "                    json.dump(instance, f, ensure_ascii=False)\n",
        "                    f.write(\"\\n\")\n",
        "                except (TypeError, ValueError) as e:\n",
        "                    raise TypeError(f\"Failed to serialize instance at index {i}: {e}\")\n",
        "    except Exception as e:\n",
        "        raise OSError(f\"Failed to write to {file_path}: {e}\")\n",
        "\n",
        "\n",
        "def create_tuning_samples(file_path: str, split: str = \"train\") -> list[dict[str, Any]]:\n",
        "    \"\"\"\n",
        "    Creates tuning samples from a jsonlines file for image comparison tasks.\n",
        "    \"\"\"\n",
        "\n",
        "    bucket_path = Path(file_path)\n",
        "    instances = []\n",
        "\n",
        "    # Read and parse the jsonlines file\n",
        "    with bucket_path.open() as f:\n",
        "        content = f.read()\n",
        "        data = [\n",
        "            json.loads(line) for line in content.strip().split(\"\\n\") if line.strip()\n",
        "        ]\n",
        "\n",
        "    # Process each instance\n",
        "    for obj in data:\n",
        "        image_path = f\"{bucket_path.parent}/{split}/{obj['img_id']}\"\n",
        "        instance = {\n",
        "            \"systemInstruction\": {\n",
        "                \"role\": \"string\",\n",
        "                \"parts\": [{\"text\": SYSTEM_INSTRUCTION}],\n",
        "            },\n",
        "            \"contents\": [\n",
        "                {\n",
        "                    \"role\": \"user\",\n",
        "                    \"parts\": [\n",
        "                        {\n",
        "                            \"fileData\": {\n",
        "                                \"mimeType\": \"image/png\",\n",
        "                                \"fileUri\": f\"{image_path}.png\",\n",
        "                            }\n",
        "                        },\n",
        "                        {\"text\": \"Image 1.\"},\n",
        "                        {\n",
        "                            \"fileData\": {\n",
        "                                \"mimeType\": \"image/png\",\n",
        "                                \"fileUri\": f\"{image_path}_2.png\",\n",
        "                            }\n",
        "                        },\n",
        "                        {\"text\": \"Image 2.\"},\n",
        "                        {\"text\": TASK_PROMPT},\n",
        "                    ],\n",
        "                },\n",
        "                {\"role\": \"model\", \"parts\": [{\"text\": obj[\"sentences\"]}]},\n",
        "            ],\n",
        "        }\n",
        "        instances.append(instance)\n",
        "    return instances\n",
        "\n",
        "\n",
        "def sample_test(input_test_file_path: str) -> dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Random sample one image_id and its data from a jsonlines file.\n",
        "    \"\"\"\n",
        "    bucket_path = Path(input_test_file_path)\n",
        "\n",
        "    with bucket_path.open() as f:\n",
        "        content = f.read()\n",
        "        data = [\n",
        "            json.loads(line) for line in content.strip().split(\"\\n\") if line.strip()\n",
        "        ]\n",
        "\n",
        "    # Randomly select one instance\n",
        "    sampled_instance = random.choice(data)\n",
        "    return sampled_instance\n",
        "\n",
        "\n",
        "def plot_images_from_uri(image_one_uri: str | Path, image_two_uri: str | Path) -> None:\n",
        "    \"\"\"\n",
        "    Plot two images side by side from their URIs using etils for cloud storage support.\n",
        "    \"\"\"\n",
        "\n",
        "    def load_image(uri):\n",
        "        path = Path(uri)\n",
        "        with path.open(\"rb\") as f:\n",
        "            img_data = f.read()\n",
        "            img = Image.open(BytesIO(img_data))\n",
        "            # Convert to RGB if image is in RGBA format\n",
        "            if img.mode == \"RGBA\":\n",
        "                img = img.convert(\"RGB\")\n",
        "            return np.array(img)\n",
        "\n",
        "    # Create figure with two subplots side by side\n",
        "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))\n",
        "\n",
        "    # Load and display first image\n",
        "    img1 = load_image(image_one_uri)\n",
        "    ax1.imshow(img1)\n",
        "    ax1.axis(\"off\")\n",
        "    ax1.set_title(\"Image 1\")\n",
        "\n",
        "    # Load and display second image\n",
        "    img2 = load_image(image_two_uri)\n",
        "    ax2.imshow(img2)\n",
        "    ax2.axis(\"off\")\n",
        "    ax2.set_title(\"Image 2\")\n",
        "\n",
        "    # Adjust layout and display\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WZsHiu5iXwHq"
      },
      "source": [
        "## Prepare your multimodal data with multiple inputs\n",
        "\n",
        "According to the [Prepare supervised fine-tuning data for Gemini models](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare) documentation, the model tuning dataset must be in the JSON Lines (JSONL) format, where each line contains a single tuning example.\n",
        "\n",
        "With multiple multimodal inputs, you have the following JSONL structure:\n",
        "\n",
        "```\n",
        "{\n",
        "   \"systemInstruction\":{\n",
        "      \"role\":\"string\",\n",
        "      \"parts\":[\n",
        "         {\n",
        "            \"text\":\"Say something to the model.\"\n",
        "         }\n",
        "      ]\n",
        "   },\n",
        "   \"contents\":[\n",
        "      {\n",
        "         \"role\":\"user\",\n",
        "         \"parts\":[\n",
        "            {\n",
        "               \"fileData\":{\n",
        "                  \"mimeType\":\"image/png\",\n",
        "                  \"fileUri\":\"gs://path/to/image1\"\n",
        "               }\n",
        "            },\n",
        "            {\n",
        "               \"text\":\"This is the image 1\"\n",
        "            },\n",
        "            {\n",
        "               \"fileData\":{\n",
        "                  \"mimeType\":\"image/png\",\n",
        "                  \"fileUri\":\"gs://path/to/image2\"\n",
        "               }\n",
        "            },\n",
        "            {\n",
        "               \"text\":\"This is the image 2\"\n",
        "            },\n",
        "            {\n",
        "               \"text\":\"Do something with image 1 and image 2.\"\n",
        "            }\n",
        "         ]\n",
        "      },\n",
        "      {\n",
        "         \"role\":\"model\",\n",
        "         \"parts\":[\n",
        "            {\n",
        "               \"text\":\"\"\n",
        "            }\n",
        "         ]\n",
        "      }\n",
        "   ]\n",
        "}\n",
        "```\n"
      ]
    },
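    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the structure above, the following sketch builds one tuning example as a Python dictionary and serializes it to a single JSONL line. The helper name `build_example` and the `gs://my-bucket/...` URIs are hypothetical placeholders for illustration, not part of the tuning API.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "\n",
        "def build_example(system_text, image_uris, prompt, answer):\n",
        "    \"\"\"Build one tuning example with several images (hypothetical helper).\"\"\"\n",
        "    parts = []\n",
        "    for i, uri in enumerate(image_uris, start=1):\n",
        "        # Each image is followed by a short text label, as in the structure above.\n",
        "        parts.append({\"fileData\": {\"mimeType\": \"image/png\", \"fileUri\": uri}})\n",
        "        parts.append({\"text\": f\"Image {i}.\"})\n",
        "    parts.append({\"text\": prompt})\n",
        "    return {\n",
        "        \"systemInstruction\": {\"role\": \"system\", \"parts\": [{\"text\": system_text}]},\n",
        "        \"contents\": [\n",
        "            {\"role\": \"user\", \"parts\": parts},\n",
        "            {\"role\": \"model\", \"parts\": [{\"text\": answer}]},\n",
        "        ],\n",
        "    }\n",
        "\n",
        "\n",
        "example = build_example(\n",
        "    \"You compare images.\",\n",
        "    [\"gs://my-bucket/a.png\", \"gs://my-bucket/a_2.png\"],  # placeholder URIs\n",
        "    \"Compare the two images and find differences.\",\n",
        "    \"The car is missing in image 2.\",\n",
        ")\n",
        "line = json.dumps(example)  # one example per line, no embedded newlines\n",
        "print(line)\n",
        "```\n",
        "\n",
        "The `create_tuning_samples` helper defined earlier follows the same pattern for the two images of each sample."
      ]
    },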
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZUGi7ZThbChr"
      },
      "source": [
        "### Replicate the multimodal dataset in your bucket\n",
        "\n",
        "Create a copy of the tutorial dataset in your bucket."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DHdC-9nj071o"
      },
      "outputs": [],
      "source": [
        "! gsutil -m -q cp -n -r {INPUT_SOURCE_DATA_URI} {INPUT_DATA_URI}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TGOrkqV3M4UJ"
      },
      "source": [
        "### Prepare the training dataset\n",
        "\n",
        "For each sample in the original JSONL train file, prepare the tuning instance.\n",
        "Then save training instances in the new JSONL train file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-R01yFuFXz4w"
      },
      "outputs": [],
      "source": [
        "input_train_file_path = Path(INPUT_DATA_URI) / \"train.jsonl\"\n",
        "output_train_file_path = Path(INPUT_DATA_URI) / \"prepared_train.jsonl\"\n",
        "\n",
        "train_instances = create_tuning_samples(input_train_file_path)\n",
        "save_jsonlines(train_instances, output_train_file_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uo2CsadnCPqF"
      },
      "source": [
        "### Prepare the val dataset\n",
        "\n",
        "For each sample in the original JSONL val file, prepare the val instance.\n",
        "Then save val instances in the new JSONL val file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cPIQoC_qiQjy"
      },
      "outputs": [],
      "source": [
        "input_val_file_path = Path(INPUT_DATA_URI) / \"val.jsonl\"\n",
        "output_val_file_path = Path(INPUT_DATA_URI) / \"prepared_val.jsonl\"\n",
        "\n",
        "val_instances = create_tuning_samples(input_val_file_path, split=\"val\")\n",
        "save_jsonlines(train_instances, output_val_file_path)"
      ]
    },
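    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before launching a tuning job, it can help to sanity-check the prepared examples. The following is a minimal sketch, not an official validator: `check_example` is a hypothetical helper that only verifies that the conversation ends with a non-empty `model` turn and that every `fileData` part points to a `gs://` URI (assumed here to be the expected location for images). The tuning service applies its own checks.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "\n",
        "def check_example(line):\n",
        "    \"\"\"Return a list of problems found in one JSONL line (hypothetical helper).\"\"\"\n",
        "    problems = []\n",
        "    example = json.loads(line)\n",
        "    contents = example.get(\"contents\", [])\n",
        "    if not contents or contents[-1].get(\"role\") != \"model\":\n",
        "        problems.append(\"last turn must have role 'model'\")\n",
        "    elif not any(p.get(\"text\") for p in contents[-1].get(\"parts\", [])):\n",
        "        problems.append(\"model turn has no text\")\n",
        "    for turn in contents:\n",
        "        for part in turn.get(\"parts\", []):\n",
        "            uri = part.get(\"fileData\", {}).get(\"fileUri\")\n",
        "            if uri is not None and not uri.startswith(\"gs://\"):\n",
        "                problems.append(f\"unexpected file URI: {uri}\")\n",
        "    return problems\n",
        "\n",
        "\n",
        "good = json.dumps(\n",
        "    {\n",
        "        \"contents\": [\n",
        "            {\"role\": \"user\", \"parts\": [{\"fileData\": {\"fileUri\": \"gs://b/1.png\"}}]},\n",
        "            {\"role\": \"model\", \"parts\": [{\"text\": \"no difference\"}]},\n",
        "        ]\n",
        "    }\n",
        ")\n",
        "bad = json.dumps({\"contents\": [{\"role\": \"user\", \"parts\": [{\"text\": \"hi\"}]}]})\n",
        "print(check_example(good))\n",
        "print(check_example(bad))\n",
        "```\n",
        "\n",
        "To apply it to the prepared data, one option is to call `check_example` on each line of the prepared JSONL files and inspect any non-empty result."
      ]
    },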
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uYAjjpdG_cpP"
      },
      "source": [
        "## Tune Gemini model by using supervised fine-tuning\n",
        "\n",
        "You can create a supervised fine-tuning job by using the Google Gen AI SDK for Python.\n",
        "\n",
        "While Vertex AI Tuning offers customizable hyperparameters, we recommend using the default settings for optimal performance—these are rigorously tested and ideal for initial runs. Advanced users can adjust these parameters to meet specific needs.\n",
        "\n",
        "[Check out the documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-use-supervised-tuning#tuning_hyperparameters) to learn more.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LIMn6fq6LM3r"
      },
      "source": [
        "#### Create a tuning job\n",
        "\n",
        "Start the tuning job with its default configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sr8r90AAiyoB"
      },
      "outputs": [],
      "source": [
        "train_dataset = str(Path(INPUT_DATA_URI) / \"prepared_train.jsonl\")\n",
        "validation_dataset = str(Path(INPUT_DATA_URI) / \"prepared_val.jsonl\")\n",
        "\n",
        "training_dataset = {\n",
        "    \"gcs_uri\": train_dataset,\n",
        "}\n",
        "\n",
        "validation_dataset = types.TuningValidationDataset(gcs_uri=validation_dataset)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GToiYxswipZA"
      },
      "outputs": [],
      "source": [
        "tuned_model_display_name = \"spot-the-difference-tuning-job\"  # @param {type:\"string\"}\n",
        "\n",
        "sft_tuning_job = client.tunings.tune(\n",
        "    base_model=MODEL_ID,\n",
        "    training_dataset=training_dataset,\n",
        "    config=types.CreateTuningJobConfig(\n",
        "        adapter_size=\"ADAPTER_SIZE_EIGHT\",\n",
        "        epoch_count=1,  # set to one to keep time and cost low\n",
        "        tuned_model_display_name=tuned_model_display_name,\n",
        "    ),\n",
        ")\n",
        "sft_tuning_job"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HF3HVA8GTh8N"
      },
      "source": [
        "⚠️ It will take ~30 mins for the model tuning job to complete on the provided dataset and set configurations/hyperparameters. ⚠️"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0bGX6OjmitaR"
      },
      "outputs": [],
      "source": [
        "tuning_job = client.tunings.get(name=sft_tuning_job.name)\n",
        "tuning_job"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Whf-i9f-LTvO"
      },
      "source": [
        "#### Monitor the tuning job\n",
        "\n",
        "Check the tuning job's progress."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sqaAHUmufq8-"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "running_states = {\n",
        "    \"JOB_STATE_PENDING\",\n",
        "    \"JOB_STATE_RUNNING\",\n",
        "}\n",
        "\n",
        "while sft_tuning_job.state in running_states:\n",
        "    print(sft_tuning_job.state)\n",
        "    tuning_job = client.tunings.get(name=sft_tuning_job.name)\n",
        "    time.sleep(10)"
      ]
    },
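    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The polling loop above runs until the job leaves a running state. A slightly more defensive variant adds a timeout so that the notebook does not block indefinitely. The sketch below is generic and makes no SDK calls: `wait_for_job`, `fetch_state`, and the simulated state sequence are hypothetical names used for illustration. In this notebook, `fetch_state` could wrap `client.tunings.get(name=...).state`.\n",
        "\n",
        "```python\n",
        "import time\n",
        "\n",
        "RUNNING_STATES = {\"JOB_STATE_PENDING\", \"JOB_STATE_RUNNING\"}\n",
        "\n",
        "\n",
        "def wait_for_job(fetch_state, poll_seconds=10.0, timeout_seconds=3600.0):\n",
        "    \"\"\"Poll fetch_state() until it returns a non-running state or time runs out.\"\"\"\n",
        "    deadline = time.monotonic() + timeout_seconds\n",
        "    state = fetch_state()\n",
        "    while state in RUNNING_STATES:\n",
        "        if time.monotonic() >= deadline:\n",
        "            raise TimeoutError(f\"job still in {state} after {timeout_seconds}s\")\n",
        "        time.sleep(poll_seconds)\n",
        "        state = fetch_state()\n",
        "    return state\n",
        "\n",
        "\n",
        "# Simulated job: pending, then running, then succeeded.\n",
        "fake_states = iter([\"JOB_STATE_PENDING\", \"JOB_STATE_RUNNING\", \"JOB_STATE_SUCCEEDED\"])\n",
        "final_state = wait_for_job(lambda: next(fake_states), poll_seconds=0.01)\n",
        "print(final_state)\n",
        "```\n",
        "\n",
        "Returning the final state lets the caller decide how to handle a failed or cancelled job."
      ]
    },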
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "69IAIZYHLhSu"
      },
      "source": [
        "### Get some tuning job details\n",
        "\n",
        "Following a successful tuning run, retrieve the registered model's resource name and the deployed tuned model's endpoint."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Up7_r1lWLo6i"
      },
      "outputs": [],
      "source": [
        "tuned_model_endpoint_name = tuning_job.tuned_model.endpoint\n",
        "tuning_experiment_name = sft_tuning_job.experiment\n",
        "\n",
        "print(\"Tuned model experiment\", tuning_experiment_name)\n",
        "print(\"Tuned model endpoint resource name:\", tuned_model_endpoint_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ag6Zac0HB_Q1"
      },
      "source": [
        "## Qualitatively evaluate the tuned model\n",
        "\n",
        "Assess the tuned model's performance in spotting differences between the two new images within this multimodal, multi-input context. In this simple tutorial, a qualitative analysis is sufficient.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qx5m-O3MVRGd"
      },
      "source": [
        "#### Extract a random sample\n",
        "\n",
        "Draw a random image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bK2Cyrhavw-Y"
      },
      "outputs": [],
      "source": [
        "input_test_file_path = Path(INPUT_DATA_URI) / \"test.jsonl\"\n",
        "test_sample = sample_test(input_test_file_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mmuoplaJVTxE"
      },
      "source": [
        "#### Prepare the image paths\n",
        "\n",
        "Specify the image paths for comparative analysis."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tnDiOqu0QpQf"
      },
      "outputs": [],
      "source": [
        "input_image_one_uri = Path(INPUT_DATA_URI) / \"test\" / f\"{test_sample['img_id']}.png\"\n",
        "input_image_two_uri = Path(INPUT_DATA_URI) / \"test\" / f\"{test_sample['img_id']}_2.png\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rqosop77ZIdJ"
      },
      "source": [
        "### Plot image samples\n",
        "\n",
        "Look at the sampled images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1UiaDg_3ZKuN"
      },
      "outputs": [],
      "source": [
        "plot_images_from_uri(input_image_one_uri, input_image_two_uri)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZAnOFlhRVYZ-"
      },
      "source": [
        "#### Generate the prediction\n",
        "\n",
        "Make a prediction with the tuned model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bp4yHwjNJbLQ"
      },
      "outputs": [],
      "source": [
        "tuned_model = tuning_job.tuned_model.endpoint\n",
        "\n",
        "contents = [\n",
        "    \"Image 1:\",\n",
        "    types.Part.from_uri(file_uri=str(input_image_one_uri), mime_type=\"image/jpeg\"),\n",
        "    \"Image 2:\",\n",
        "    types.Part.from_uri(file_uri=str(input_image_two_uri), mime_type=\"image/jpeg\"),\n",
        "]\n",
        "\n",
        "response = client.models.generate_content(\n",
        "    model=tuned_model,\n",
        "    contents=contents,\n",
        "    config={\n",
        "        \"temperature\": 0,\n",
        "    },\n",
        ")\n",
        "\n",
        "response"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qGvaDslGVdIC"
      },
      "source": [
        "#### Compare the ground truth with the new answer from tuned model\n",
        "\n",
        "Assess the accuracy of the tuned model's response by comparing it to the ground truth."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yPXoHT_dTaHt"
      },
      "outputs": [],
      "source": [
        "print(\"Ground truth answer:\", test_sample[\"sentences\"])\n",
        "print(\"Generated answer with fine-tuned model:\", response.text.strip())"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "sft_gemini_on_multiple_images.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
